Integrated Gradients vs Activation Patching in GPT2-Small

Background, motivation and set up

Objective: Compare attributions using integrated gradients and activation patching, and investigate the discrepancies between the two methods.

Motivation:

  • Understand when and why IG and AP disagree: e.g. methodological limitations, suitability to particular model tasks, etc.
  • Investigate if discrepancies help uncover different hidden model behaviours
  • Understand when and why linear approximations to activation patching fail
  • Investigate limitations of using activation patching for evaluations: whether differing results stem from other, unknown factors rather than from the evaluated method being “incorrect”

Set-up:

We load the transformer model GPT2-Small, which has 12 layers, 12 attention heads per layer, embedding size 768, and 4 × 768 = 3,072 neurons in each feed-forward layer. We use GPT2-Small because 1) it is a relatively small transformer whose behaviour is qualitatively comparable to larger SOTA models, and 2) much of the interpretability literature on circuits focuses on this model.

Code
import torch
import numpy as np

from captum.attr import LayerIntegratedGradients

from transformer_lens.utils import get_act_name, get_device
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

import seaborn as sns
import matplotlib.pyplot as plt
Code
# Disable gradients globally to save memory; Captum re-enables them internally when computing attributions
torch.set_grad_enabled(False)

# device = get_device()
device = torch.device("cpu")
model = HookedTransformer.from_pretrained("gpt2-small", device=device)
Loaded pretrained model gpt2-small into HookedTransformer

Attribution for GPT2-Small

We scale up our earlier experiments to implement integrated gradients and activation patching on a larger transformer model. We use the same counterfactual inputs, based on the Indirect Object Identification task.

Code
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

clean_input = model.to_tokens(clean_prompt)
corrupted_input = model.to_tokens(corrupted_prompt)

def logits_to_logit_diff(logits, correct_answer=" John", incorrect_answer=" Mary"):
    # model.to_single_token maps a string value of a single token to the token index for that token
    correct_index = model.to_single_token(correct_answer)
    incorrect_index = model.to_single_token(incorrect_answer)
    return logits[0, -1, correct_index] - logits[0, -1, incorrect_index]

# Explicitly expose per-head attention results and the MLP input as hook points
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

clean_logits, clean_cache = model.run_with_cache(clean_input)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_input)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
Clean logit difference: 4.276
Corrupted logit difference: -2.738

Integrated Gradients

Code
def run_from_layer_fn(x, original_input, prev_layer, reset_hooks_end=True):
    # Force the layer before the target layer to output the given values, i.e. feed the given activations into the target layer
    # The original_input itself does not matter (its activations at prev_layer are overwritten); it only fixes the shapes
    
    def fwd_hook(act, hook):
        x.requires_grad_(True)
        return x
    
    logits = model.run_with_hooks(
        original_input,
        fwd_hooks=[(prev_layer.name, fwd_hook)],
        reset_hooks_end=reset_hooks_end
    )
    logit_diff = logits_to_logit_diff(logits).unsqueeze(0)
    return logit_diff

def compute_layer_to_output_attributions(original_input, layer_input, layer_baseline, target_layer, prev_layer):
    # Take the model starting from the target layer
    forward_fn = lambda x: run_from_layer_fn(x, original_input, prev_layer)
    # Attribute to the target_layer's output
    ig_embed = LayerIntegratedGradients(forward_fn, target_layer, multiply_by_inputs=True)
    attributions, approximation_error = ig_embed.attribute(
        inputs=layer_input,
        baselines=layer_baseline,
        attribute_to_layer_input=False,
        return_convergence_delta=True,
    )
    print(f"\nError (delta) for {target_layer.name} attribution: {approximation_error.item()}")
    return attributions
Code
# Optionally load precomputed results to skip the expensive computation below
mlp_ig_zero_results = torch.load("mlp_ig_zero_results.pt")
attn_ig_zero_results = torch.load("attn_ig_zero_results.pt")
Code
# Gradient attribution using the zero baseline, as originally recommended
mlp_ig_zero_results = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp)
attn_ig_zero_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

# Calculate integrated gradients for each layer
for layer in range(model.cfg.n_layers):
    # Gradient attribution on heads
    hook_name = get_act_name("result", layer)
    target_layer = model.hook_dict[hook_name]
    prev_layer_hook = get_act_name("z", layer)
    prev_layer = model.hook_dict[prev_layer_hook]

    layer_clean_input = clean_cache[prev_layer_hook]
    layer_corrupt_input = torch.zeros_like(layer_clean_input)

    # Note the argument order: the zero activations are the IG input and the clean activations the baseline
    attributions = compute_layer_to_output_attributions(clean_input, layer_corrupt_input, layer_clean_input, target_layer, prev_layer) # shape [1, seq_len, d_head, d_model]
    # Calculate attribution score based on mean over each embedding, for each token
    print(attributions.shape)
    per_token_score = attributions.mean(dim=3)
    score = per_token_score.mean(dim=1)
    attn_ig_zero_results[layer] = score

    # Gradient attribution on MLP neurons
    hook_name = get_act_name("post", layer)
    target_layer = model.hook_dict[hook_name]
    prev_layer_hook = get_act_name("mlp_in", layer)
    prev_layer = model.hook_dict[prev_layer_hook]

    layer_clean_input = clean_cache[prev_layer_hook]
    layer_corrupt_input = torch.zeros_like(layer_clean_input)
    
    attributions = compute_layer_to_output_attributions(clean_input, layer_corrupt_input, layer_clean_input, target_layer, prev_layer) # shape [1, seq_len, d_model]
    print(attributions.shape)
    score = attributions.mean(dim=1)
    mlp_ig_zero_results[layer] = score

torch.save(mlp_ig_zero_results, "mlp_ig_zero_results.pt")
torch.save(attn_ig_zero_results, "attn_ig_zero_results.pt")

Error (delta) for blocks.0.attn.hook_result attribution: 1.5820902585983276
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.0.mlp.hook_post attribution: 6.293460845947266
torch.Size([1, 17, 3072])

Error (delta) for blocks.1.attn.hook_result attribution: 0.07838684320449829
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.1.mlp.hook_post attribution: 0.66867595911026
torch.Size([1, 17, 3072])

Error (delta) for blocks.2.attn.hook_result attribution: 0.22218307852745056
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.2.mlp.hook_post attribution: 1.6563737392425537
torch.Size([1, 17, 3072])

Error (delta) for blocks.3.attn.hook_result attribution: 0.7106841206550598
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.3.mlp.hook_post attribution: 3.3847177028656006
torch.Size([1, 17, 3072])

Error (delta) for blocks.4.attn.hook_result attribution: 1.0539896488189697
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.4.mlp.hook_post attribution: 0.9486498832702637
torch.Size([1, 17, 3072])

Error (delta) for blocks.5.attn.hook_result attribution: 1.1305818557739258
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.5.mlp.hook_post attribution: 3.122290849685669
torch.Size([1, 17, 3072])

Error (delta) for blocks.6.attn.hook_result attribution: 0.8233247995376587
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.6.mlp.hook_post attribution: 2.46148681640625
torch.Size([1, 17, 3072])

Error (delta) for blocks.7.attn.hook_result attribution: 1.1756279468536377
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.7.mlp.hook_post attribution: 0.29078352451324463
torch.Size([1, 17, 3072])

Error (delta) for blocks.8.attn.hook_result attribution: 1.6124217510223389
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.8.mlp.hook_post attribution: -0.8814370632171631
torch.Size([1, 17, 3072])

Error (delta) for blocks.9.attn.hook_result attribution: -0.9382901191711426
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.9.mlp.hook_post attribution: 0.8226296305656433
torch.Size([1, 17, 3072])

Error (delta) for blocks.10.attn.hook_result attribution: -0.5772985219955444
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.10.mlp.hook_post attribution: 0.05703079327940941
torch.Size([1, 17, 3072])

Error (delta) for blocks.11.attn.hook_result attribution: -0.18520982563495636
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.11.mlp.hook_post attribution: -0.8039997816085815
torch.Size([1, 17, 3072])
Code
bound = max(torch.max(mlp_ig_zero_results), abs(torch.min(mlp_ig_zero_results)))

plt.figure(figsize=(75, 10))
plt.imshow(mlp_ig_zero_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound, aspect="auto")
plt.title("MLP Neuron Gradient Attribution (Integrated Gradients)")
plt.xticks(np.arange(0, model.cfg.d_mlp, 250))
plt.xlabel("Neuron Index")
plt.yticks(list(range(model.cfg.n_layers)))
plt.ylabel("Layer")
plt.colorbar()
plt.show()

Code
bound = max(torch.max(attn_ig_zero_results), abs(torch.min(attn_ig_zero_results)))

plt.figure(figsize=(10, 5))
plt.imshow(attn_ig_zero_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound)
plt.title("Attention Head Gradient Attribution (Integrated Gradients)")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar()
plt.show()

Activation Patching

Code
def patch_neuron_hook(activations: torch.Tensor, hook: HookPoint, cache: ActivationCache, neuron_idx: int):
    # Replace the activations for the target neuron with activations from the cached run.
    cached_activations = cache[hook.name]
    activations[:, :, neuron_idx] = cached_activations[:, :, neuron_idx]
    return activations

def patch_attn_hook(activations: torch.Tensor, hook: HookPoint, cache: ActivationCache, head_idx: int):
    # Replace the activations for the target attention head with activations from the cached run.
    cached_activations = cache[hook.name]
    activations[:, :, head_idx, :] = cached_activations[:, :, head_idx, :]
    return activations

baseline_diff = (clean_logit_diff - corrupted_logit_diff).item()
Code
# Optionally load precomputed results (the next cell also attempts this before recomputing)
mlp_patch_results = torch.load("mlp_patch_results.pt")
attn_patch_results = torch.load("attn_patch_results.pt")
Code
class StopExecution(Exception):
    def _render_traceback_(self):
        return []
    
# Check if we have run activation patching already (expensive)
try:
    mlp_patch_results = torch.load("mlp_patch_results.pt")
    attn_patch_results = torch.load("attn_patch_results.pt")
    raise StopExecution
except FileNotFoundError:
    mlp_patch_results = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp)
    attn_patch_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

    for layer in range(model.cfg.n_layers):
        # Activation patching on heads
        for head in range(model.cfg.n_heads):
            hook_name = get_act_name("result", layer)
            temp_hook = lambda act, hook: patch_attn_hook(act, hook, corrupted_cache, head)

            with model.hooks(fwd_hooks=[(hook_name, temp_hook)]):
                patched_logits = model(clean_input)

            patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
            # Normalise result by clean and corrupted logit difference
            attn_patch_results[layer, head] = (patched_logit_diff - clean_logit_diff) / baseline_diff

        # Activation patching on MLP neurons
        for neuron in range(model.cfg.d_mlp):
            hook_name = get_act_name("post", layer)
            temp_hook = lambda act, hook: patch_neuron_hook(act, hook, corrupted_cache, neuron)
            
            with model.hooks(fwd_hooks=[(hook_name, temp_hook)]):
                patched_logits = model(clean_input)

            patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
            # Normalise result by clean and corrupted logit difference
            mlp_patch_results[layer, neuron] = (patched_logit_diff - clean_logit_diff) / baseline_diff
    
    torch.save(mlp_patch_results, "mlp_patch_results.pt")
    torch.save(attn_patch_results, "attn_patch_results.pt")
Code
bound = max(torch.max(mlp_patch_results), abs(torch.min(mlp_patch_results)))

plt.figure(figsize=(75, 10))
plt.imshow(mlp_patch_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound, aspect="auto")
plt.title("MLP Neuron Attribution (Activation Patching)")
plt.xticks(np.arange(0, model.cfg.d_mlp, 250))
plt.xlabel("Neuron Index")
plt.yticks(list(range(model.cfg.n_layers)))
plt.ylabel("Layer")
plt.colorbar()
plt.show()

Code
bound = max(torch.max(attn_patch_results), abs(torch.min(attn_patch_results)))

plt.figure(figsize=(10, 5))
plt.imshow(attn_patch_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound)
plt.title("Attention Head Attribution (Activation Patching)")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar()
plt.show()

Analysis of initial attribution methods

Code
# Plot the attribution scores against each other; perfect agreement would lie on the line y = x.

x = mlp_ig_zero_results.flatten().numpy()
y = mlp_patch_results.flatten().numpy()

sns.regplot(x=x, y=y)
plt.xlabel("Integrated Gradients MLP Attribution Scores")
plt.ylabel("Activation Patching MLP Attribution Scores")
plt.show()

print(f"Correlation coefficient between IG and AP attributions for neurons: {np.corrcoef(x, y)[0, 1]}")

x = attn_ig_zero_results.flatten().numpy()
y = attn_patch_results.flatten().numpy()

sns.regplot(x=x, y=y)
plt.xlabel("Integrated Gradients Attention Attribution Scores")
plt.ylabel("Activation Patching Attention Attribution Scores")
plt.show()

print(f"Correlation coefficient between IG and AP attributions for attention: {np.corrcoef(x, y)[0, 1]}")

Correlation coefficient between IG and AP attributions for neurons: 0.23509090894797813

Correlation coefficient between IG and AP attributions for attention: 0.30637427255128147
Code
def get_top_k_by_abs(data, k):
    _, indices = torch.topk(data.flatten().abs(), k)
    top_k_values = torch.gather(data.flatten(), 0, indices)
    formatted_indices = []
    for idx in indices:
        layer = idx // model.cfg.d_mlp
        neuron_pos = idx % model.cfg.d_mlp
        formatted_indices.append([layer, neuron_pos])
    return torch.tensor(formatted_indices), top_k_values

def get_attributions_above_threshold(data, percentile):
    threshold = torch.min(data) + percentile * (torch.max(data) - torch.min(data))
    masked_data = torch.where(data > threshold, data, 0)
    nonzero_indices = torch.nonzero(masked_data)
    return nonzero_indices, masked_data

top_mlp_ig_zero_indices, top_mlp_ig_zero_results = get_top_k_by_abs(mlp_ig_zero_results, 30)
top_mlp_patch_indices, top_mlp_patch_results = get_top_k_by_abs(mlp_patch_results, 30)

top_mlp_ig_zero_sets = set([tuple(t.tolist()) for t in top_mlp_ig_zero_indices])
top_mlp_patch_sets = set([tuple(t.tolist()) for t in top_mlp_patch_indices])

intersection = top_mlp_ig_zero_sets.intersection(top_mlp_patch_sets)
union = top_mlp_ig_zero_sets.union(top_mlp_patch_sets)
jaccard = len(intersection) / len(union)

print(f"Jaccard score for MLP neurons: {jaccard}")
Jaccard score for MLP neurons: 0.1111111111111111
Code
from sklearn.preprocessing import MaxAbsScaler

mlp_ig_zero_results_1d = mlp_ig_zero_results.flatten().numpy()
mlp_patch_results_1d = mlp_patch_results.flatten().numpy()

# Mean difference plot with scaled data

scaled_mlp_ig_results_1d = MaxAbsScaler().fit_transform(mlp_ig_zero_results_1d.reshape(-1, 1))
scaled_mlp_patch_results_1d = MaxAbsScaler().fit_transform(mlp_patch_results_1d.reshape(-1, 1))

mean = np.mean([scaled_mlp_ig_results_1d, scaled_mlp_patch_results_1d], axis=0)
diff = scaled_mlp_patch_results_1d - scaled_mlp_ig_results_1d
md = np.mean(diff) # Mean of the difference
sd = np.std(diff, axis=0) # Standard deviation of the difference

plt.figure(figsize=(10, 6))
sns.regplot(x=mean, y=diff, fit_reg=True, scatter=True)
plt.axhline(md, color='gray', linestyle='--', label="Mean difference")
plt.axhline(md + 1.96*sd, color='pink', linestyle='--', label="1.96 SD of difference")
plt.axhline(md - 1.96*sd, color='lightblue', linestyle='--', label="-1.96 SD of difference")
plt.xlabel("Mean of attribution scores per neuron")
plt.ylabel("Difference (activation patching - integrated gradients) per neuron")
plt.title("Mean-difference plot of scaled attribution scores from integrated gradients and activation patching")
plt.legend()
plt.show()

Code
from sklearn.preprocessing import MaxAbsScaler

scaled_attn_ig_zero_results = MaxAbsScaler().fit_transform(attn_ig_zero_results)
scaled_attn_patch_results = MaxAbsScaler().fit_transform(attn_patch_results)

diff_attn_results = scaled_attn_ig_zero_results - scaled_attn_patch_results
diff_attn_results_abs = np.abs(scaled_attn_ig_zero_results) - np.abs(scaled_attn_patch_results)

plt.figure(figsize=(10,10))
plt.subplot(1, 2, 1)
plt.imshow(diff_attn_results, cmap="RdBu")
plt.title("Difference in attributions for attention heads")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar(orientation="horizontal")

plt.subplot(1, 2, 2)
plt.imshow(diff_attn_results_abs, cmap="RdBu")
plt.title("Difference in (absolute) attributions for attention heads")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar(orientation="horizontal")
plt.tight_layout()
plt.show()

Comparable baselines

Hypothesis: one possible reason for the discrepancy between patching and IG is that the two methods probe activations drawn from different distributions.

Both methods rely on counterfactual reasoning. IG integrates gradients along a path from a baseline (often zeros, intended to represent an uninformative input) to the given input, whereas activation patching computes the logit difference between two counterfactual inputs. If the counterfactuals used differ, this alone could cause a discrepancy.

To evaluate this hypothesis, we compute IG and AP on GPT2-Small with the same counterfactual inputs, using the corrupted activations as the IG baseline.
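It is worth recalling why the baseline matters at all. IG satisfies a completeness axiom: attributions sum to the difference in outputs between the input and the baseline, so changing the baseline changes both the magnitudes and the signs of attributions. A toy sketch (the function `f` and its inputs here are purely illustrative, not part of the experiments):

```python
import torch

def f(x):
    # Toy scalar function of a 2-d input (purely illustrative)
    return x[0] * x[1] + x[0] ** 2

def integrated_gradients(f, x, baseline, steps=256):
    # Right-endpoint Riemann approximation of the IG path integral
    total = torch.zeros_like(x)
    with torch.enable_grad():  # gradients are globally disabled in this notebook
        for k in range(1, steps + 1):
            point = (baseline + (k / steps) * (x - baseline)).detach().requires_grad_(True)
            total += torch.autograd.grad(f(point), point)[0]
    return (x - baseline) * total / steps

x = torch.tensor([1.0, 2.0])
zero_baseline = torch.zeros(2)
corrupt_baseline = torch.tensor([2.0, 1.0])  # a different counterfactual input

attr_zero = integrated_gradients(f, x, zero_baseline)
attr_corrupt = integrated_gradients(f, x, corrupt_baseline)

# Completeness: attributions sum to f(x) - f(baseline), so they flip sign here
print(attr_zero.tolist(), (f(x) - f(zero_baseline)).item())
print(attr_corrupt.tolist(), (f(x) - f(corrupt_baseline)).item())
```

With the zero baseline the attributions sum to f(x) − f(0) = 3, but with the corrupt baseline they sum to f(x) − f(baseline) = −3, so the two baselines give attributions that are not even comparable in sign.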

Code
# Optionally load precomputed results to skip the expensive computation below
mlp_ig_results = torch.load("mlp_ig_results.pt")
attn_ig_results = torch.load("attn_ig_results.pt")
Code
# Gradient attribution for neurons in MLP layers
mlp_ig_results = torch.zeros(model.cfg.n_layers, model.cfg.d_mlp)
# Gradient attribution for attention heads
attn_ig_results = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

# Calculate integrated gradients for each layer
for layer in range(model.cfg.n_layers):
    # Gradient attribution on heads
    hook_name = get_act_name("result", layer)
    target_layer = model.hook_dict[hook_name]
    prev_layer_hook = get_act_name("z", layer)
    prev_layer = model.hook_dict[prev_layer_hook]

    layer_clean_input = clean_cache[prev_layer_hook]
    layer_corrupt_input = corrupted_cache[prev_layer_hook]

    attributions = compute_layer_to_output_attributions(clean_input, layer_corrupt_input, layer_clean_input, target_layer, prev_layer) # shape [1, seq_len, d_head, d_model]
    # Calculate attribution score based on mean over each embedding, for each token
    print(attributions.shape)
    per_token_score = attributions.mean(dim=3)
    score = per_token_score.mean(dim=1)
    attn_ig_results[layer] = score

    # Gradient attribution on MLP neurons
    hook_name = get_act_name("post", layer)
    target_layer = model.hook_dict[hook_name]
    prev_layer_hook = get_act_name("mlp_in", layer)
    prev_layer = model.hook_dict[prev_layer_hook]

    layer_clean_input = clean_cache[prev_layer_hook]
    layer_corrupt_input = corrupted_cache[prev_layer_hook]
    
    attributions = compute_layer_to_output_attributions(clean_input, layer_corrupt_input, layer_clean_input, target_layer, prev_layer) # shape [1, seq_len, d_model]
    print(attributions.shape)
    score = attributions.mean(dim=1)
    mlp_ig_results[layer] = score

torch.save(mlp_ig_results, "mlp_ig_results.pt")
torch.save(attn_ig_results, "attn_ig_results.pt")

Error (delta) for blocks.0.attn.hook_result attribution: -0.08010423183441162
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.0.mlp.hook_post attribution: 5.367663860321045
torch.Size([1, 17, 3072])

Error (delta) for blocks.1.attn.hook_result attribution: 0.05616430938243866
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.1.mlp.hook_post attribution: -0.10697655379772186
torch.Size([1, 17, 3072])

Error (delta) for blocks.2.attn.hook_result attribution: -0.012761879712343216
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.2.mlp.hook_post attribution: -0.11945552378892899
torch.Size([1, 17, 3072])

Error (delta) for blocks.3.attn.hook_result attribution: 0.2565889358520508
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.3.mlp.hook_post attribution: 0.1360594630241394
torch.Size([1, 17, 3072])

Error (delta) for blocks.4.attn.hook_result attribution: 0.051070213317871094
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.4.mlp.hook_post attribution: -0.08819704502820969
torch.Size([1, 17, 3072])

Error (delta) for blocks.5.attn.hook_result attribution: 0.3684248626232147
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.5.mlp.hook_post attribution: 0.24765437841415405
torch.Size([1, 17, 3072])

Error (delta) for blocks.6.attn.hook_result attribution: 0.3670154809951782
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.6.mlp.hook_post attribution: -0.040538765490055084
torch.Size([1, 17, 3072])

Error (delta) for blocks.7.attn.hook_result attribution: 1.3272550106048584
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.7.mlp.hook_post attribution: 0.45815184712409973
torch.Size([1, 17, 3072])

Error (delta) for blocks.8.attn.hook_result attribution: 2.3821561336517334
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.8.mlp.hook_post attribution: 0.12053033709526062
torch.Size([1, 17, 3072])

Error (delta) for blocks.9.attn.hook_result attribution: 1.4569836854934692
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.9.mlp.hook_post attribution: 0.6440849304199219
torch.Size([1, 17, 3072])

Error (delta) for blocks.10.attn.hook_result attribution: -1.0445181131362915
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.10.mlp.hook_post attribution: 0.45244646072387695
torch.Size([1, 17, 3072])

Error (delta) for blocks.11.attn.hook_result attribution: -1.987703800201416
torch.Size([1, 17, 12, 768])

Error (delta) for blocks.11.mlp.hook_post attribution: 0.400459885597229
torch.Size([1, 17, 3072])
Code
bound = max(torch.max(mlp_ig_results), abs(torch.min(mlp_ig_results)))

plt.figure(figsize=(75, 10))
plt.imshow(mlp_ig_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound, aspect="auto")
plt.title("MLP Neuron Gradient Attribution (Integrated Gradients) with Corrupt Baseline")
plt.xticks(np.arange(0, model.cfg.d_mlp, 250))
plt.xlabel("Neuron Index")
plt.yticks(list(range(model.cfg.n_layers)))
plt.ylabel("Layer")
plt.colorbar()
plt.show()

Code
bound = max(torch.max(attn_ig_results), abs(torch.min(attn_ig_results)))

plt.figure(figsize=(10, 5))
plt.imshow(attn_ig_results.detach(), cmap='RdBu', vmin=-bound, vmax=bound)
plt.title("Attention Head Gradient Attribution (Integrated Gradients) with Corrupt Baseline")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar()
plt.show()

Analysis of comparable baselines

Code
# Plot the attribution scores against each other; perfect agreement would lie on the line y = x.

x = mlp_ig_results.flatten().numpy()
y = mlp_patch_results.flatten().numpy()

sns.regplot(x=x, y=y)
plt.xlabel("Integrated Gradients (Corrupt Baseline) MLP Attribution Scores")
plt.ylabel("Activation Patching MLP Attribution Scores")
plt.show()

print(f"Correlation coefficient between IG with corrupted baseline and AP attributions for neurons: {np.corrcoef(x, y)[0, 1]}")

x = attn_ig_results.flatten().numpy()
y = attn_patch_results.flatten().numpy()

sns.regplot(x=x, y=y)
plt.xlabel("Integrated Gradients (Corrupt Baseline) Attention Attribution Scores")
plt.ylabel("Activation Patching Attention Attribution Scores")
plt.show()

print(f"Correlation coefficient between IG with corrupted baseline and AP attributions for attention: {np.corrcoef(x, y)[0, 1]}")

Correlation coefficient between IG with corrupted baseline and AP attributions for neurons: 0.9852227509307566

Correlation coefficient between IG with corrupted baseline and AP attributions for attention: 0.9547628738711134

The correlation between attribution scores for MLP neurons and attention heads is extremely high! This indicates that, with the same baseline, both methods obtain very similar attribution scores.
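One caveat: a Pearson coefficient can be dominated by a few large-magnitude components that both methods agree on, even when the bulk of the scores are uncorrelated, which is why a rank-based check such as the top-k overlap is a useful complement. A toy sketch on synthetic arrays (illustrative stand-ins, not the experimental scores) showing the effect:

```python
import numpy as np

def spearman(a, b):
    # Spearman rank correlation = Pearson correlation of the ranks
    ranks_a = np.argsort(np.argsort(a))
    ranks_b = np.argsort(np.argsort(b))
    return np.corrcoef(ranks_a, ranks_b)[0, 1]

rng = np.random.default_rng(0)
# Mostly independent noise, plus a handful of shared large-magnitude components
ig = rng.normal(0, 0.01, 1000)
ap = rng.normal(0, 0.01, 1000)
shared = np.array([1.0, -0.8, 0.6, -0.5, 0.4])
ig[:5] = shared
ap[:5] = shared

pearson = np.corrcoef(ig, ap)[0, 1]
print(f"Pearson: {pearson:.3f}, Spearman: {spearman(ig, ap):.3f}")
```

Here the Pearson coefficient is high purely because of the five shared outliers, while the rank correlation stays near zero.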

Code
# Reuse get_top_k_by_abs and get_attributions_above_threshold defined in the earlier analysis

top_mlp_ig_indices, top_mlp_ig_results = get_top_k_by_abs(mlp_ig_results, 30)
top_mlp_patch_indices, top_mlp_patch_results = get_top_k_by_abs(mlp_patch_results, 30)

top_mlp_ig_sets = set([tuple(t.tolist()) for t in top_mlp_ig_indices])
top_mlp_patch_sets = set([tuple(t.tolist()) for t in top_mlp_patch_indices])

intersection = top_mlp_ig_sets.intersection(top_mlp_patch_sets)
union = top_mlp_ig_sets.union(top_mlp_patch_sets)
jaccard = len(intersection) / len(union)

print(f"Jaccard score for MLP neurons: {jaccard}")
Jaccard score for MLP neurons: 0.875
Code
from sklearn.preprocessing import MaxAbsScaler

mlp_ig_results_1d = mlp_ig_results.flatten().numpy()
mlp_patch_results_1d = mlp_patch_results.flatten().numpy()

# Mean difference plot with scaled data

scaled_mlp_ig_results_1d = MaxAbsScaler().fit_transform(mlp_ig_results_1d.reshape(-1, 1))
scaled_mlp_patch_results_1d = MaxAbsScaler().fit_transform(mlp_patch_results_1d.reshape(-1, 1))

mean = np.mean([scaled_mlp_ig_results_1d, scaled_mlp_patch_results_1d], axis=0)
diff = scaled_mlp_patch_results_1d - scaled_mlp_ig_results_1d
md = np.mean(diff) # Mean of the difference
sd = np.std(diff, axis=0) # Standard deviation of the difference

plt.figure(figsize=(10, 6))
sns.regplot(x=mean, y=diff, fit_reg=True, scatter=True)
plt.axhline(md, color='gray', linestyle='--', label="Mean difference")
plt.axhline(md + 1.96*sd, color='pink', linestyle='--', label="1.96 SD of difference")
plt.axhline(md - 1.96*sd, color='lightblue', linestyle='--', label="-1.96 SD of difference")
plt.xlabel("Mean of attribution scores per neuron")
plt.ylabel("Difference (activation patching - integrated gradients) per neuron")
plt.title("Mean-difference plot of scaled attribution scores from integrated gradients and activation patching")
plt.legend()
plt.show()

The mean-difference plot suggests there is still some proportional bias: the difference between activation patching scores and integrated gradients scores grows as the attribution score deviates from 0. Integrated gradients appears to assign more extreme attribution scores than activation patching.
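The slope of the regression line in the mean-difference plot quantifies this proportional bias. A minimal sketch on synthetic scores exhibiting the same qualitative pattern (the arrays are illustrative stand-ins, not the experimental results):

```python
import numpy as np

rng = np.random.default_rng(0)
# Illustrative stand-ins: patching scores systematically shrunk relative to IG
ig = rng.normal(0, 1, 500)
ap = 0.8 * ig + rng.normal(0, 0.05, 500)

mean = (ig + ap) / 2
diff = ap - ig  # matches the plot's (activation patching - integrated gradients)

# A negative slope means AP scores shrink towards zero relative to IG scores;
# a slope near zero would indicate a constant offset rather than proportional bias
slope, intercept = np.polyfit(mean, diff, 1)
print(f"slope = {slope:.3f}, intercept = {intercept:.3f}")
```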

Code
# Ad-hoc rescaling (factor chosen by inspection) to bring the IG scores to a
# magnitude comparable to the normalised patching scores for side-by-side plots
scaled_attn_ig_results = attn_ig_results * 1e5
scaled_attn_patch_results = attn_patch_results

plt.figure(figsize=(10,10))
plt.subplot(2, 2, 1)
plt.imshow(scaled_attn_ig_results, cmap="RdBu", vmin=-0.4, vmax=0.4)
plt.title("Attention Head Attribution (IG, rescaled)")
plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))
plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))
plt.colorbar(orientation="horizontal")

plt.figure(figsize=(10,10))
plt.subplot(2, 2, 1)
plt.imshow(scaled_attn_patch_results, cmap="RdBu", vmin=-0.4, vmax=0.4)
plt.title("Attention Head Attribution (Activation Patching)")
plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))
plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))
plt.colorbar(orientation="horizontal")

diff_attn_results = scaled_attn_ig_results - scaled_attn_patch_results
diff_attn_results_abs = np.abs(scaled_attn_ig_results) - np.abs(scaled_attn_patch_results)

plt.figure(figsize=(10,10))
plt.subplot(1, 2, 1)
plt.imshow(diff_attn_results, cmap="RdBu", vmin=-0.2, vmax=0.2)
plt.title("Difference in attributions for attention heads")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar(orientation="horizontal")

plt.subplot(1, 2, 2)
plt.imshow(diff_attn_results_abs, cmap="RdBu", vmin=-0.2, vmax=0.2)
plt.title("Difference in (absolute) attributions for attention heads")

plt.xlabel("Head Index")
plt.xticks(list(range(model.cfg.n_heads)))

plt.ylabel("Layer")
plt.yticks(list(range(model.cfg.n_layers)))

plt.colorbar(orientation="horizontal")
plt.tight_layout()
plt.show()

Remaining questions include:

  • Although both methods are aligned when the baselines are the same, this doesn’t mean that they capture the most faithful attribution scores. For instance, if we change the baseline (which is arbitrarily set to some counterfactual value), we could get different components in the circuit. How do we select the best baselines such that faithful circuits are highlighted?

  • There are still some discrepancies in attribution scores, particularly for attention heads. What could be the cause of different attention head attribution scores?

Comparison to IOI circuit

The attention heads highlighted in the original IOI paper seem to correspond with the attention heads highlighted by both methods.

[Figure: IOI circuit diagram]

Investigation of discrepancies

General-purpose components

Of all the outliers, head (9, 6) is the only one which is strongly highlighted by integrated gradients, but not by activation patching. The other outliers have larger attribution scores assigned by integrated gradients compared to activation patching, but are highlighted by both methods.

Hypothesis: the components which are highlighted only by integrated gradients are important attention heads, which are used generically in both the clean run and corrupted run.

  • They are not detected as strongly by activation patching, which only measures the change in logit difference when corrupted activations are swapped in, and therefore under-weights components that behave similarly in both runs.
  • Suggested by Ferrando and Voita (2024)
Code
import json

class IOIDataset:

    def __init__(self, src_path: str):
        with open(src_path) as f:
            self.data = json.load(f)
        
    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            prompts_answers = [(d['prompt'], d['answer']) for d in self.data[idx]]
            return prompts_answers
        return (self.data[idx]['prompt'], self.data[idx]['answer'])

Code
ioi_dataset = IOIDataset("ioi_dataset.json")[:10_000]

Experiment 1: zero ablation

To test this, we can ablate the components whose attribution-score differences fall outside the limits of agreement (more than 1.96 standard deviations from the mean difference). If ablating them hurts performance, this shows the components are necessary, even though they are not picked up by activation patching.

Code
def evaluate_ioi_performance(ioi_dataset: IOIDataset, model: HookedTransformer):
    num_correct = 0
    num_eval = 0
    for prompt, answer in ioi_dataset:
        if num_eval % 50 == 0:
            print(f"Evaluating prompt {num_eval}")
        outputs = model.generate(input=prompt, max_new_tokens=3, do_sample=False, verbose=False)
        generated_text = outputs.removeprefix(prompt).strip()
        generated_answer = generated_text.split()[0] if generated_text else ""
        if answer in generated_answer:
            num_correct += 1
        num_eval += 1
    return num_correct / num_eval
Code
# Measure baseline performance of model on IOI task
baseline_performance = evaluate_ioi_performance(ioi_dataset, model)
print(baseline_performance)
Code
# Identify statistically significant outlier components

diff = np.abs(scaled_attn_patch_results - scaled_attn_ig_results)
diff_std = np.std(diff.numpy())

print(f"Standard deviation of differences: {diff_std}")

attn_outliers = []
for layer in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        if diff[layer, head_idx] > 1.96*diff_std:
            attn_outliers.append((layer, head_idx))

print(attn_outliers)
Standard deviation of differences: 0.020343296229839325
[(7, 3), (8, 10), (9, 6), (9, 9), (10, 6), (11, 10)]
Code
# Ablate components: zero ablation

all_performance_scores = []

for layer, head_idx in attn_outliers:
    attn_hook = get_act_name("result", layer)

    def ablate_hook(activations, hook):
        activations[:, :, head_idx, :] = 0
        return activations

    with model.hooks(fwd_hooks=[(attn_hook, ablate_hook)]):
        performance = evaluate_ioi_performance(ioi_dataset, model)
        all_performance_scores.append(performance)
        print(f"Performance after ablating attention head {(layer, head_idx)}: {performance}")

# TODO: Mean ablation, random ablation
Code
print(baseline_performance)
print(all_performance_scores)
0.7864
[0.6981, 0.7877, 0.8052, 0.8009, 0.7245, 0.9361]
Code
np.save("all_performance_scores.npy", all_performance_scores)
Code
for layer, idx in attn_outliers:
    print(f"Attention head {(layer, idx)}")
    ig_score = attn_ig_results[layer, idx]
    patch_score = attn_patch_results[layer, idx]
    print(f"IG score: {ig_score:.5f}, AP score: {patch_score:.5f}\n")
Attention head (7, 3)
IG score: -0.00000, AP score: -0.10789

Attention head (8, 10)
IG score: -0.00000, AP score: -0.14233

Attention head (9, 6)
IG score: -0.00000, AP score: 0.00889

Attention head (9, 9)
IG score: -0.00000, AP score: -0.11382

Attention head (10, 6)
IG score: -0.00000, AP score: -0.07559

Attention head (11, 10)
IG score: 0.00000, AP score: 0.21538
Code
plt.title("Model performance after zero ablation of attention head outliers")
plt.xlabel("Ablated attention head position")
plt.ylabel("Model performance on IOI tasks")

xs = ["None"] + [str(t) for t in attn_outliers]
ys = [baseline_performance] + all_performance_scores

plt.bar(xs, ys)
plt.show()

Code
# Correlation between difference in attribution scores and difference in performance

performance_differences = []
ig_outlier_scores = []
ap_outlier_scores = []
score_differences = []

for i in range(len(all_performance_scores)):
    performance_differences.append(all_performance_scores[i] - baseline_performance)
    layer, attn_idx = attn_outliers[i]
    ap_outlier_scores.append(attn_patch_results[layer, attn_idx])
    ig_outlier_scores.append(attn_ig_results[layer, attn_idx])
    score_diff = attn_patch_results[layer, attn_idx] - attn_ig_results[layer, attn_idx]
    score_differences.append(score_diff)
Code
sns.regplot(x=score_differences, y=performance_differences)
plt.ylabel("Difference between performance and baseline performance")
plt.xlabel("Difference between patching and IG attribution scores")
plt.show()

print(f"Correlation coefficient between attribution score differences and performance score differences: {np.corrcoef(score_differences, performance_differences)[0, 1]}")

Correlation coefficient between attribution score differences and performance score differences: 0.8298786622776018
Code
# Correlation between attribution scores and performance change

sns.regplot(x=ig_outlier_scores, y=performance_differences)
plt.ylabel("Difference between performance and baseline performance")
plt.xlabel("IG attribution scores")
plt.show()

print(f"Correlation coefficient between IG attribution score and performance score differences: {np.corrcoef(ig_outlier_scores, performance_differences)[0, 1]}")

sns.regplot(x=ap_outlier_scores, y=performance_differences)
plt.ylabel("Difference between performance and baseline performance")
plt.xlabel("Activation Patching attribution scores")
plt.show()

print(f"Correlation coefficient between AP attribution score and performance score differences: {np.corrcoef(ap_outlier_scores, performance_differences)[0, 1]}")

Correlation coefficient between IG attribution score and performance score differences: 0.7296915879554248

Correlation coefficient between AP attribution score and performance score differences: 0.8298774188113335
  • Ablating head (9, 6) does not have a strong effect on the performance.
    • Conclusion: components which are only identified by integrated gradients may not be important for the specific task.
  • Interestingly, ablating heads identified as moderately important by activation patching (e.g. (8, 10) and (9, 9)) does not have a significant impact on performance either.
    • Conclusion: neither method identifies the minimal set of important attention heads.
    • Comparison to the original IOI paper: under mean ablation, these heads (and (9, 6)) are highlighted and impact performance more noticeably.
  • There is no strong pattern or correlation between which components receive higher attribution scores under IG or AP and their impact on performance.

Experiment 2: Mean ablation

Instead of using zero ablation, we use mean ablation to study the effect of a component’s removal on the model’s performance.

Code
import random

# Get mean activations
model = model.to("cpu")

attn_outlier_hooks = [get_act_name("result", layer_idx) for layer_idx, _ in attn_outliers]

random_prompts = random.sample(ioi_dataset, 100)
prompts_tokens = model.to_tokens([p for p, _ in random_prompts])
_, prompt_cache = model.run_with_cache(prompts_tokens, names_filter=lambda x: x in attn_outlier_hooks)

mean_activations = {}
for key in prompt_cache.keys():
    mean_values_over_prompts = torch.mean(prompt_cache[key], dim=0)
    mean_activations[key] = torch.mean(mean_values_over_prompts, dim=0)
Moving model to device:  cpu
Code
# Ablate components: mean ablation

all_performance_scores_mean_ablation = []

for layer, head_idx in attn_outliers:
    attn_hook = get_act_name("result", layer)

    def ablate_hook(activations, hook):
        mean_hook_acts = mean_activations[hook.name]
        activations[:, :, head_idx, :] = mean_hook_acts[head_idx]
        return activations

    with model.hooks(fwd_hooks=[(attn_hook, ablate_hook)]):
        performance = evaluate_ioi_performance(ioi_dataset, model)
        all_performance_scores_mean_ablation.append(performance)
        print(f"Performance after mean ablating attention head {(layer, head_idx)}: {performance}")
Code
print(all_performance_scores_mean_ablation)

np.save("all_performance_scores_mean_ablation.npy", all_performance_scores_mean_ablation)
[]
Code
plt.title("Model performance after mean ablation of attention head outliers")
plt.xlabel("Ablated attention head position")
plt.ylabel("Model performance on IOI tasks")

xs = ["None"] + [str(t) for t in attn_outliers]
baseline_performance = 0.7864
ys = [baseline_performance] + all_performance_scores_mean_ablation

plt.bar(xs, ys)
plt.show()

Code
# Correlation between difference in attribution scores and difference in performance

mean_performance_differences = []
ig_outlier_scores = []
ap_outlier_scores = []
score_differences = []

for i in range(len(all_performance_scores_mean_ablation)):
    mean_performance_differences.append(all_performance_scores_mean_ablation[i] - baseline_performance)
    layer, attn_idx = attn_outliers[i]
    ap_outlier_scores.append(attn_patch_results[layer, attn_idx])
    ig_outlier_scores.append(attn_ig_results[layer, attn_idx])
    score_diff = attn_patch_results[layer, attn_idx] - attn_ig_results[layer, attn_idx]
    score_differences.append(score_diff)
Code
# Correlation between attribution scores and performance change

sns.regplot(x=ig_outlier_scores, y=mean_performance_differences)
plt.ylabel("Difference between performance and baseline performance under mean ablation")
plt.xlabel("IG attribution scores")
plt.show()

print(f"Correlation coefficient between IG attribution score and performance score difference under mean ablation: {np.corrcoef(ig_outlier_scores, mean_performance_differences)[0, 1]}")

sns.regplot(x=ap_outlier_scores, y=mean_performance_differences)
plt.ylabel("Difference between performance and baseline performance")
plt.xlabel("Activation Patching attribution scores")
plt.show()

print(f"Correlation coefficient between AP attribution score and performance score difference under mean ablation: {np.corrcoef(ap_outlier_scores, mean_performance_differences)[0, 1]}")

Correlation coefficient between IG attribution score and performance score difference under mean ablation: 0.7868078399469188

Correlation coefficient between AP attribution score and performance score difference under mean ablation: 0.8659126942995752
  • Ablating head (9, 6) has a moderate but significant effect on the performance.
    • The ~15% improvement in performance under mean ablation is statistically significant at p = 0.05.
    • Conclusion: components which are only identified by integrated gradients may still be important for the specific task.
  • For some heads identified as important by both methods (e.g. (8, 10) and (9, 9)), mean ablation does not have a significant impact on performance either.
    • Conclusion: neither method identifies the minimal set of important attention heads.
  • There is a moderate correlation between the attribution scores a component receives under IG or AP and its impact on performance.
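The significance claim above can be sanity-checked with a normal-approximation test on accuracy over the evaluation prompts. This is a hedged sketch: the helper function is our own, and the counts below are illustrative rather than the notebook's exact numbers.

```python
import math

def two_sided_proportion_test(successes: int, n: int, p0: float):
    """z-test of an observed accuracy against a baseline accuracy p0,
    using the normal approximation to the binomial distribution."""
    p_hat = successes / n
    se = math.sqrt(p0 * (1 - p0) / n)
    z = (p_hat - p0) / se
    p_value = math.erfc(abs(z) / math.sqrt(2))  # two-sided p-value
    return z, p_value

# Illustrative: a ~15-point accuracy change over 10,000 prompts is
# overwhelmingly significant against the 0.7864 baseline.
z, p = two_sided_proportion_test(successes=9_361, n=10_000, p0=0.7864)
print(z, p)
```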

Overestimation and underestimation

From the ablation experiments, we can see that IG generally assigns higher attribution scores than AP, but some of these scores are overestimates. Conversely, AP underestimates the attribution scores for some heads.

  • IG has more true positives, but also more false positives: IG has higher recall, but AP has higher precision.
  • Overall the results between the methods are very similar.
  • What causes false positives in IG?

Hypothesis: the outliers overestimated by IG are due to the shape of the output curve between the IG baseline and input.

  • IG calculates change in loss based on integrating gradients between two input values.
  • A high attribution score can be caused by strong gradients (sensitivity) around intermediate points on the path between the two input values. In this case, the highlighted component would be important for an “in-between” task (represented by a different counterfactual pair) rather than the target task.
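This hypothesis can be illustrated with a toy 1-D function (illustrative only, not the model): the output barely differs between baseline and input, but the path between them crosses a high-gradient bump.

```python
import numpy as np

# Toy function with a bump midway between baseline (0.0) and input (1.0).
f = lambda x: np.exp(-((x - 0.5) ** 2) / 0.01)

baseline, inp = 0.0, 1.0
alphas = np.linspace(0.0, 1.0, 201)
xs = baseline + alphas * (inp - baseline)
grads = np.gradient(f(xs), xs)
dalpha = alphas[1] - alphas[0]

# Signed IG recovers the (tiny) endpoint difference by completeness...
ig = (inp - baseline) * np.sum(grads) * dalpha
# ...but the accumulated |gradient| along the path reveals the hidden bump.
path_sensitivity = np.sum(np.abs(grads)) * dalpha

print(ig, path_sensitivity)  # ig ≈ 0, path_sensitivity ≈ 2
```

In one dimension the signed IG score cannot be fooled, but in the high-dimensional model setting, per-component contributions from such in-between regions need not cancel, which is one way IG can assign a large score to a component whose endpoint behaviour barely changes.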

Overestimation

To test this, we can visualise the outputs over the interpolation intervals that IG sums. We focus on attention head (9, 6) because it is highlighted more strongly by IG than by AP.

Visualising outputs

Here, we visualise the outputs with respect to interpolated layer inputs to see how the shape of the output curves might change when attribution scores are different.

Code
from captum.attr._utils.approximation_methods import approximation_parameters

def visualise_attn_interpolated_outputs(target_layer_num, target_pos):
    hook_name = get_act_name("result", target_layer_num)
    target_layer = model.hook_dict[hook_name]

    layer_clean_input = clean_cache[hook_name] # Baseline

    # Only corrupt at target head
    layer_corrupt_input = layer_clean_input.clone()
    layer_corrupt_input[:, :, target_pos] = corrupted_cache[hook_name][:, :, target_pos]

    # Take the model starting from the target layer
    forward_fn = lambda x: run_from_layer_fn(x, clean_input, target_layer)
    _, alphas_func = approximation_parameters("gausslegendre")
    alphas = alphas_func(n_steps)

    with torch.autograd.set_grad_enabled(True):
        interpolated_inputs = [layer_clean_input + alpha * (layer_corrupt_input - layer_clean_input) for alpha in alphas]
        outputs = [forward_fn(i) for i in interpolated_inputs]

    print(outputs)

    plt.title(f"Model output at interpolated inputs: head {(target_layer_num, target_pos)}")
    plt.plot(alphas, [o.item() for o in outputs])
    plt.xlabel("Interpolation coefficient")
    plt.ylabel("Output (logit difference)")
    plt.ylim(0, 6)
    plt.show()
Code
# Highlighted only by IG

visualise_attn_interpolated_outputs(9, 6)

Code
visualise_attn_interpolated_outputs(0, 2)

The attention head highlighted (weakly) only by integrated gradients has a very low gradient: its outputs do not change significantly as the input is interpolated. This contrasts with attention heads that have stronger attribution scores under both methods, whose outputs have larger gradients. Its gradients are nevertheless larger than those of attention heads with negligible attribution scores, whose output curves are essentially flat.

It may simply be that IG is more error-prone because it overestimates the importance of components whose outputs fluctuate only slightly.

Code
# Highlighted by both, strong impact on performance under ablation

visualise_attn_interpolated_outputs(7, 3)
visualise_attn_interpolated_outputs(10, 6)

Code
# Highlighted by both, positive effect on performance under ablation

visualise_attn_interpolated_outputs(11, 10)

Code
# Highlighted in both methods, lack of ablation effect

visualise_attn_interpolated_outputs(9, 9)
visualise_attn_interpolated_outputs(8, 10)

Code
# Low attribution scores in both methods

visualise_attn_interpolated_outputs(2, 5)
visualise_attn_interpolated_outputs(10, 3)
[Output: two lists of 50 interpolated logit differences, one per head. For head (2, 5) the values rise only slightly, from 4.2764 to 4.2857; for head (10, 3) they fall slightly, from 4.2764 to 4.2578. Both curves are essentially flat.]

So the attention heads with high attribution scores tend to have peaks and troughs at the input counterfactual and baseline counterfactual respectively. This raises the question: can we visualise how the output changes over a much wider range of interpolation? Are there peaks and troughs outside this range for the other attention heads, which IG would highlight given the optimal counterfactual pairs?

  • IDEA: identify the peaks and troughs for the outputs wrt a specific component. Fix the baseline; vary the activation of the specific component up to a different input and check the outputs.
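A hedged sketch of this idea, written generically so it works with any scalar-output function (the helper name and defaults are our own; in this notebook, forward_fn would be the run_from_layer_fn wrapper and the two inputs the clean and corrupted activations of the target component):

```python
import numpy as np

def scan_interpolation(forward_fn, baseline, inp, lo=-1.0, hi=2.0, n_points=61):
    """Evaluate a scalar output along the line through `baseline` and `inp`,
    extending the interpolation coefficient beyond [0, 1] to look for peaks
    and troughs outside the baseline-input segment."""
    alphas = [lo + i * (hi - lo) / (n_points - 1) for i in range(n_points)]
    outputs = [float(forward_fn(baseline + a * (inp - baseline))) for a in alphas]
    return alphas, outputs

# Toy check: a quadratic whose trough lies outside [0, 1] only becomes
# visible once the scan range is extended.
alphas, outs = scan_interpolation(lambda x: ((x - 1.5) ** 2).sum(), np.zeros(3), np.ones(3))
print(alphas[outs.index(min(outs))])  # trough near alpha = 1.5, outside [0, 1]
```

Because the helper only uses arithmetic and float(), it accepts both NumPy arrays and (detached) torch tensors.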

“Optimal” contrastive pairs

Consider head (9, 6), which according to the IOI paper is a name mover head:

“Name Mover Heads output the remaining name. They are active at END, attend to previous names in the sentence, and copy the names they attend to”.

We use a different contrastive pair related to the IOI task to try to obtain a high attribution score under IG. We change the corrupted prompt such that 1) the output should change from “John” to “Mary”, and 2) the name-copying head (the hypothesised role of head (9, 6)) is even more important.

Code
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After the cat and the dog went to the store, Mary gave a bottle of milk to"

clean_input, corrupted_input = model.to_tokens([clean_prompt, corrupted_prompt])

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

clean_logits, clean_cache = model.run_with_cache(clean_input)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_input)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")
Clean logit difference: 1.458
Corrupted logit difference: -0.274
Code
# Gradient attribution on heads
hook_name = get_act_name("result", 9)
target_layer = model.hook_dict[hook_name]
prev_layer_hook = get_act_name("z", 9)
prev_layer = model.hook_dict[prev_layer_hook]

layer_clean_input = clean_cache[prev_layer_hook]
layer_corrupt_input = corrupted_cache[prev_layer_hook]

attributions = compute_layer_to_output_attributions(clean_input, layer_clean_input, layer_corrupt_input, target_layer, prev_layer) # shape [1, seq_len, n_heads, d_model]
# Calculate attribution score based on mean over each embedding, for each token
per_token_score = attributions.mean(dim=3)
score = per_token_score.mean(dim=1)

print(score[:,6])

Error (delta) for blocks.9.attn.hook_result attribution: -0.9106521606445312
tensor([3.3620e-08])
Code
# Original scores

print(attn_ig_results[9, 6])
print(attn_patch_results[9, 6])
tensor(-4.5160e-07)
tensor(0.0089)
Code
visualise_attn_interpolated_outputs(9, 6)
[Output: 50 interpolated logit differences decreasing smoothly from 1.4576 to 1.1183.]

Code
# Get activation patching scores

hook_name = get_act_name("result", 9)
temp_hook = lambda act, hook: patch_attn_hook(act, hook, corrupted_cache, 6)

with model.hooks(fwd_hooks=[(hook_name, temp_hook)]):
    patched_logits = model(clean_input)

patched_logit_diff = logits_to_logit_diff(patched_logits).detach()
# Normalise result by clean and corrupted logit difference
ap_score = (patched_logit_diff - clean_logit_diff) / baseline_diff
print(ap_score)
tensor(-0.0456)

Changing the baseline inputs so that the output gradients vary more does not appear to affect the IG score much, but it does increase the magnitude of the activation patching score.

It seems clear that the baselines used for attribution methods are extremely important hyperparameters, but there is no clear intuition as to which baseline is “best” for evaluating a specific model behaviour. This motivates a new method that identifies the optimal counterfactuals for making attribution methods highlight specific components.